Propose to refactor output normalization in several transformers #11850

Draft · wants to merge 8 commits into main

Conversation

@tolgacangoz (Contributor) commented on Jul 2, 2025

(I attempted to make these replacements, if you don't mind :)

This PR will be activated when the SkyReels-V2 models' integration PR is merged into main.

Replace `FP32LayerNorm` with `AdaLayerNorm` in `WanTransformer3DModel`, `WanVACETransformer3DModel`, ..., simplifying the forward pass and enhancing model parallelism compatibility.

Context: #11518 (comment)

@yiyixuxu @a-r-r-o-w

Replace the final `FP32LayerNorm` and manual shift/scale application with a single `AdaLayerNorm` module in both the `WanTransformer3DModel` and `WanVACETransformer3DModel`.

This change simplifies the forward pass by encapsulating the adaptive normalization logic within the `AdaLayerNorm` layer, removing the need for a separate `scale_shift_table`. The `_no_split_modules` list is also updated to include `norm_out` for compatibility with model parallelism.
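For illustration, here is a minimal sketch of the shape of that replacement. The module below is a simplified stand-in for diffusers' `AdaLayerNorm`; the names (`inner_dim`, `AdaLayerNormSketch`), the dimensions, and the plain `nn.LayerNorm` used in place of `FP32LayerNorm` (which additionally upcasts to float32) are assumptions for the example, not the library's actual implementation:

```python
import torch
import torch.nn as nn

batch, seq_len, inner_dim = 2, 16, 64  # hypothetical sizes for illustration
hidden_states = torch.randn(batch, seq_len, inner_dim)
temb = torch.randn(batch, inner_dim)  # conditioning (time) embedding

# Before: a norm without affine parameters plus a separately learned scale_shift_table,
# with the shift/scale applied manually in the model's forward pass.
scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)
norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False)
shift, scale = (scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
out_before = norm_out(hidden_states) * (1 + scale) + shift

# After: a single adaptive-norm module derives shift/scale from temb internally,
# so the forward pass reduces to one call and the separate table parameter disappears.
class AdaLayerNormSketch(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.linear = nn.Linear(dim, 2 * dim)  # takes over the role of scale_shift_table
        self.norm = nn.LayerNorm(dim, elementwise_affine=False)

    def forward(self, x: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
        shift, scale = self.linear(temb).unsqueeze(1).chunk(2, dim=-1)
        return self.norm(x) * (1 + scale) + shift

out_after = AdaLayerNormSketch(inner_dim)(hidden_states, temb)
```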
Updates the key mapping for the `head.modulation` layer to `norm_out.linear` in the model conversion script.

This correction ensures that weights are loaded correctly for both standard and VACE transformer models.

Replaces the manual implementation of adaptive layer normalization, which used a separate `scale_shift_table` and `nn.LayerNorm`, with the unified `AdaLayerNorm` module.

This change simplifies the forward pass logic in several transformer models by encapsulating the normalization and modulation steps into a single component. It also adds `norm_out` to `_no_split_modules` for model parallelism compatibility.
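A small, hedged sketch of the parallelism-related part; the class name below is made up, and only the idea of listing `norm_out` in `_no_split_modules` comes from this PR's description:

```python
from diffusers.configuration_utils import ConfigMixin
from diffusers.models.modeling_utils import ModelMixin

class WanLikeTransformer3DModel(ModelMixin, ConfigMixin):
    # Listing "norm_out" keeps the adaptive output norm from being split across
    # devices when the model is loaded with a device_map.
    _no_split_modules = ["WanTransformerBlock", "norm_out"]
```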
Corrects the target key for `head.modulation` to `norm_out.linear.weight`.

This ensures the weights are correctly mapped to the weight parameter of the output normalization layer during model conversion for both transformer types.
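A sketch of the kind of rename entry this refers to in the conversion script; the dictionary name `TRANSFORMER_KEYS_RENAME_DICT` and the helper function are assumptions about the script's structure, while the key strings come from the commit messages above:

```python
TRANSFORMER_KEYS_RENAME_DICT = {
    # Per this PR, head.modulation now targets the weight of the linear
    # projection inside the output AdaLayerNorm.
    "head.modulation": "norm_out.linear.weight",
    # ... other entries left unchanged ...
}

def rename_key(key: str) -> str:
    # Apply substring renames, as conversion scripts commonly do.
    for old, new in TRANSFORMER_KEYS_RENAME_DICT.items():
        key = key.replace(old, new)
    return key
```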
Adds a default zero-initialized bias tensor for the transformer's output normalization layer if it is missing from the original state dictionary.
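A hypothetical helper illustrating that fallback (the function name and key strings are assumptions, consistent with the mapping sketched above):

```python
import torch

def ensure_norm_out_bias(state_dict: dict) -> dict:
    """If the original checkpoint has no bias for the output norm's linear
    projection, insert a zero tensor so the converted model loads cleanly."""
    weight_key = "norm_out.linear.weight"
    bias_key = "norm_out.linear.bias"
    if weight_key in state_dict and bias_key not in state_dict:
        weight = state_dict[weight_key]
        state_dict[bias_key] = torch.zeros(weight.shape[0], dtype=weight.dtype)
    return state_dict
```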
@tolgacangoz changed the title from "Refactor output normalization in several transformers" to "Propose to refactor output normalization in several transformers" on Jul 3, 2025